CART: Classification and Regression Trees
Simulate Data
library(data.table)
# Simulate n = 5000 observations. For each k in 1..7 there is a constant
# odds-ratio column (ork), a binary indicator (vark) and a normally
# distributed covariate (varkn); p0 = 0.2 is the baseline probability.
# The arguments below are evaluated left to right, so the random draws (and
# therefore the simulated data under this seed) depend on this exact order.
set.seed(123456)
n <- 5000
dt <- data.table(
p0 = rep(0.2, n)
# or1 = 1 means log(or1) = 0, so var1/var1n end up with no effect on the outcome.
, or1 = rep(1, n)
, var1 = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.3, 0.7))
# Note: sd = 1 here, whereas var2n..var7n use sd = 2.
, var1n = rnorm(n, 0, 1)
, or2 = rep(1.1, n)
, var2 = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.4, 0.6))
, var2n = rnorm(n, 0, 2)
, or3 = rep(1.2, n)
, var3 = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.2, 0.8))
, var3n = rnorm(n, 0, 2)
, or4 = rep(1.5, n)
, var4 = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.3, 0.7))
, var4n = rnorm(n, 0, 2)
, or5 = rep(1.7, n)
, var5 = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.5, 0.5))
, var5n = rnorm(n, 0, 2)
, or6 = rep(2, n)
, var6 = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.4, 0.6))
, var6n = rnorm(n, 0, 2)
# Largest effect: or7 = 5 applies to both var7 and var7n.
, or7 = rep(5, n)
, var7 = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.1, 0.9))
, var7n = rnorm(n, 0, 2)
)
# Turn the simulated predictors into a per-row outcome probability.
# - odds0 is the baseline odds implied by p0.
# - log_odds adds, for every k in 1..7, log(ork) times the binary indicator
#   vark and log(ork) times the continuous covariate varkn.
# - p is the inverse logit of log_odds.
# plogis() replaces exp(x) / (1 + exp(x)). The old form overflows to
# Inf / Inf = NaN once log_odds exceeds about 709 (e.g. with larger odds
# ratios), and a NaN probability would then break the outcome draw.
# plogis() stays within [0, 1].
dt <- dt[, odds0 := p0 / (1 - p0)
][, log_odds := log(odds0) +
    var1 * log(or1) + var1n * log(or1) +
    var2 * log(or2) + var2n * log(or2) +
    var3 * log(or3) + var3n * log(or3) +
    var4 * log(or4) + var4n * log(or4) +
    var5 * log(or5) + var5n * log(or5) +
    var6 * log(or6) + var6n * log(or6) +
    var7 * log(or7) + var7n * log(or7)
][, p := plogis(log_odds)]
# Draw one Bernoulli outcome per element of `p`: 1 with probability p[i],
# otherwise 0. Each element gets its own sample() call, in order, so the
# random stream under a given seed matches an element-by-element loop.
# Returns a numeric vector (keeping any names of `p`). For zero-length
# input it returns an empty list, as mapply() does.
vsample <- function(p) {
  draw_one <- function(prob_one) {
    sample(c(1, 0), size = 1, replace = TRUE,
           prob = c(prob_one, 1 - prob_one))
  }
  mapply(draw_one, p)
}
# Add the simulated 0/1 outcome: one Bernoulli draw per row using that row's p.
dt <- dt[, `:=`(outcome = vsample(p))]
unique(dt[, .(or1, or2, or3, or4, or5, or6, or7)]) %>% prt(caption = "Variables with Odds Ratios")
| or1 | or2 | or3 | or4 | or5 | or6 | or7 |
|---|---|---|---|---|---|---|
| 1 | 1.1 | 1.2 | 1.5 | 1.7 | 2 | 5 |
GLM
# Logistic regression of the binary outcome on all 14 simulated predictors;
# the estimated odds ratios are expected to be close to or1..or7.
m <- glm(
  formula = outcome ~ var1 + var2 + var3 + var4 + var5 + var6 + var7 +
    var1n + var2n + var3n + var4n + var5n + var6n + var7n,
  family = binomial,
  data = dt
)
library(sjPlot)
tab_model(m)
| | outcome | | |
|---|---|---|---|
| Predictors | Odds Ratios | CI | p |
| (Intercept) | 0.14 | 0.09 – 0.21 | <0.001 |
| var1 | 1.15 | 0.94 – 1.41 | 0.160 |
| var2 | 1.11 | 0.92 – 1.34 | 0.271 |
| var3 | 1.41 | 1.12 – 1.78 | 0.004 |
| var4 | 1.76 | 1.44 – 2.16 | <0.001 |
| var5 | 1.97 | 1.64 – 2.38 | <0.001 |
| var6 | 2.34 | 1.93 – 2.85 | <0.001 |
| var7 | 5.44 | 4.01 – 7.42 | <0.001 |
| var1n | 1.05 | 0.95 – 1.15 | 0.344 |
| var2n | 1.11 | 1.06 – 1.17 | <0.001 |
| var3n | 1.20 | 1.14 – 1.26 | <0.001 |
| var4n | 1.50 | 1.42 – 1.58 | <0.001 |
| var5n | 1.69 | 1.61 – 1.79 | <0.001 |
| var6n | 2.10 | 1.97 – 2.23 | <0.001 |
| var7n | 5.15 | 4.67 – 5.71 | <0.001 |
| Observations | 5000 | ||
| R2 Tjur | 0.620 | ||
CART: the Full Model
library(rpart)
library(rpart.plot)
library(RColorBrewer)
library(rattle)
# Model inputs: the seven binary indicators (var1..var7) followed by the
# seven continuous covariates (var1n..var7n), in that order.
predictors <- c(
  paste0("var", 1:7),
  paste0("var", 1:7, "n")
)
# Model formula for `outcome` on all `predictors` (built by the project
# helper Wu::wu_formula).
frml <- Wu::wu_formula(outcome = "outcome", predictors = predictors)

# Grow a large classification tree with the information (entropy) split
# criterion. cp = 0 means no split is rejected for failing to improve the fit
# by a cp fraction, so the tree can be pruned back afterwards using the
# cross-validated error table (xval = 20). The seed makes the CV folds
# reproducible.
set.seed(123456)
tr <- rpart(
  formula = frml,
  data = dt,
  method = "class",
  parms = list(split = "information"),
  control = rpart.control(cp = 0, xval = 20, maxdepth = 30,
                          minsplit = 10, minbucket = 5),
  model = TRUE,
  x = TRUE,
  y = TRUE
)
rpart.plot(tr, type = 2, extra = 106, tweak = 2, under = TRUE)
Prune the CART Model
# Pick the complexity parameter (CP) of the subtree with the smallest
# cross-validated error (the "xerror" column of the CP table), then prune
# the full tree to that CP.
bestcp <- local({
  cp_table <- tr$cptable
  min_xerror_row <- which.min(cp_table[, "xerror"])
  cp_table[min_xerror_row, "CP"]
})
## bestcp <- tr$cptable[12, "CP"]
trp <- prune(tr, cp = bestcp)
rpart.plot(trp, type = 2, extra = 106, tweak = 2, under = TRUE)
AUC on the Training Data
library(pROC)
# Class probabilities of the pruned tree on the same data it was trained on.
# Column 2 (presumably P(outcome == 1)) is the ROC score.
# direction = "<" follows pROC's convention that controls have lower scores
# than cases. ci = TRUE requests a confidence interval for the AUC.
pred <- predict(trp, newdata = dt, type = "prob")
r <- roc(
  response = dt$outcome,
  predictor = pred[, 2],
  direction = "<",
  ci = TRUE
)
plot(r)
R sessionInfo
R version 4.1.2 (2021-11-01)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.3 LTS
Matrix products: default
BLAS: /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0
locale: [1] LC_CTYPE=C.UTF-8 LC_NUMERIC=C LC_TIME=C.UTF-8
[4] LC_COLLATE=C.UTF-8 LC_MONETARY=C.UTF-8 LC_MESSAGES=C.UTF-8
[7] LC_PAPER=C.UTF-8 LC_NAME=C LC_ADDRESS=C
[10] LC_TELEPHONE=C LC_MEASUREMENT=C.UTF-8 LC_IDENTIFICATION=C
attached base packages: [1] stats graphics grDevices utils datasets methods base
other attached packages: [1] pROC_1.18.0 rattle_5.5.1 bitops_1.0-7
[4] tibble_3.1.7 RColorBrewer_1.1-3 rpart.plot_3.1.1
[7] rpart_4.1-15 sjPlot_2.8.10 Wu_0.0.0.9000
[10] flexdashboard_0.5.2 lme4_1.1-27.1 Matrix_1.4-0
[13] mgcv_1.8-38 nlme_3.1-152 png_0.1-7
[16] scales_1.2.0 nnet_7.3-16 labelled_2.8.0
[19] kableExtra_1.3.4 plotly_4.9.4.1 gridExtra_2.3
[22] ggplot2_3.3.6 DT_0.18 tableone_0.13.0
[25] magrittr_2.0.3 lubridate_1.7.10 dplyr_1.0.9
[28] plyr_1.8.6 data.table_1.14.2 rmdformats_1.0.2
[31] knitr_1.39
loaded via a namespace (and not attached): [1] insight_0.17.1 webshot_0.5.2 httr_1.4.2 backports_1.2.1
[5] tools_4.1.2 bslib_0.2.5.1 sjlabelled_1.2.0 utf8_1.2.2
[9] R6_2.5.1 DBI_1.1.1 lazyeval_0.2.2 colorspace_2.0-3
[13] withr_2.5.0 tidyselect_1.1.2 emmeans_1.7.4-1 compiler_4.1.2
[17] performance_0.9.0 cli_3.3.0 rvest_1.0.0 xml2_1.3.3
[21] bookdown_0.22 bayestestR_0.12.1 sass_0.4.0 mvtnorm_1.1-3
[25] systemfonts_1.0.2 stringr_1.4.0 digest_0.6.29 minqa_1.2.4
[29] rmarkdown_2.10 svglite_2.1.0 pkgconfig_2.0.3 htmltools_0.5.2
[33] fastmap_1.1.0 highr_0.9 htmlwidgets_1.5.4 rlang_1.0.2
[37] rstudioapi_0.13 jquerylib_0.1.4 generics_0.1.2 jsonlite_1.7.2
[41] crosstalk_1.1.1 parameters_0.18.0 Rcpp_1.0.8.3 munsell_0.5.0
[45] fansi_1.0.3 lifecycle_1.0.1 stringi_1.7.6 yaml_2.3.5
[49] MASS_7.3-54 grid_4.1.2 sjmisc_2.8.9 forcats_0.5.1
[53] crayon_1.5.1 lattice_0.20-45 ggeffects_1.1.2 haven_2.4.1
[57] splines_4.1.2 sjstats_0.18.1 hms_1.1.0 klippy_0.0.0.9500
[61] pillar_1.7.0 boot_1.3-28 estimability_1.3 effectsize_0.7.0
[65] glue_1.6.2 evaluate_0.15 mitools_2.4 modelr_0.1.8
[69] vctrs_0.4.1 nloptr_1.2.2.2 gtable_0.3.0 purrr_0.3.4
[73] tidyr_1.1.3 assertthat_0.2.1 datawizard_0.4.1 xfun_0.31
[77] broom_0.8.0 xtable_1.8-4 survey_4.0 survival_3.2-13
[81] viridisLite_0.4.0 ellipsis_0.3.2